package com.campus.counseling.service;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

public interface LSTMModelTrainer {
    /**
     * 初始化模型
     */
    void initModel();
    
    /**
     * 训练模型
     */
    void train(DataSet trainingData);
    
    /**
     * 预测
     * @param features 输入特征
     * @return 预测结果数组
     */
    INDArray predict(INDArray features);
    
    /**
     * 保存模型
     */
    void saveModel();
    
    /**
     * 加载模型
     */
    MultiLayerNetwork loadModel();
} 